import random

import numpy as np
import torch

from centralized_verification.agents.decentralized_training.independent_agents.deep_q_learner import DeepQLearner
from centralized_verification.agents.decentralized_training.independent_agents.tabular_q_learner import TabularQLearner
from centralized_verification.agents.decentralized_training.multi_agent_wrapper import MultiAgentLearnerWrapper
from centralized_verification.agents.utils import linear_epsilon_anneal_episodes, linear_epsilon_anneal_steps
from centralized_verification.configuration import TestingLimits, Configuration, TrainingLimits, TestConfiguration
from centralized_verification.envs.continuous_grid_world import ContinuousGridWorld
from centralized_verification.envs.fast_grid_world import FastGridWorld
from centralized_verification.envs.fast_grid_world_2d_obs import FastGridWorld2DObs
from centralized_verification.envs.fast_grid_world_nearby_obs import FastGridWorldNearbyObs
from centralized_verification.envs.fast_grid_world_partial_obs import FastGridWorldPartialObs
from centralized_verification.envs.particle_momentum import ParticleMomentum
from centralized_verification.envs.utils import map_parser
from centralized_verification.models.norm_simple_mlp import NormSimpleMLP
from centralized_verification.models.simple_cnn import SimpleCNN
from centralized_verification.models.simple_mlp import SimpleMLP
from centralized_verification.shields.centralized_shield import CentralizedShieldOracle
from centralized_verification.shields.decentralized_shield import DecentralizedShieldOracle
from centralized_verification.shields.no_shield import NoShield
from centralized_verification.shields.slugs_shielding.combine_identical_states import load_centralized_shield, \
    load_decentralized_shield
from centralized_verification.shields.slugs_shielding.label_extractor import FullObsGridWorldLabelExtractor2Agents, \
    ParticleMomentumLabelExtractor2Agents, ContinuousGridWorldLabelExtractor
from centralized_verification.shields.slugs_shielding.slugs_centralized_shield import SlugsCentralizedShield
from centralized_verification.shields.slugs_shielding.slugs_decentralized_shield import SlugsDecentralizedShield
from centralized_verification.train import train_loop, test_loop
from centralized_verification.training_state import maybe_load_from_checkpoint
from experiments.utils.parallel_run import parse_args_and_run_parallel_csv_experiment

# Lookup table from the value of the "shield" / "evaluation_shield" parameter
# to the shield class that get_config_from_params instantiates for it.
shields = {
    "decentralized": DecentralizedShieldOracle,
    "centralized": CentralizedShieldOracle,
    "none": NoShield,
    "slugs_centralized": SlugsCentralizedShield,
    "slugs_decentralized": SlugsDecentralizedShield
}


def is_true(bool_ish):
    """Interpret a loosely-typed flag (e.g. a CSV cell) as a boolean.

    Strings are compared case-insensitively and with surrounding whitespace
    ignored, so ``"true"``, ``"TRUE"`` and ``" yes "`` are all truthy. The
    accepted spellings are ``1``, ``true``, ``t``, ``yes`` and ``y``.
    Non-string values are truthy only when they compare equal to ``True``
    (which includes the integer ``1``).

    :param bool_ish: the value to interpret; may be a bool, number, string or None
    :return: ``True`` for a recognised truthy value, ``False`` for anything else
    """
    if isinstance(bool_ish, str):
        # Normalise so that hand-edited parameter files are not silently
        # treated as False because of casing or stray spaces.
        return bool_ish.strip().lower() in ("1", "true", "t", "yes", "y")
    return bool_ish in (True, 1)


def _build_env(params, randomize_starts):
    """Instantiate the environment selected by ``params["map_type"]``.

    :param params: flat mapping of experiment parameters
    :param randomize_starts: already-parsed ``randomize_starts`` flag
    :return: the constructed environment
    :raises NotImplementedError: for an unknown map type or grid-world observation type
    """
    map_type = params["map_type"]

    if map_type in ("GridWorld", "ContinuousGridWorld"):
        map_name = params["grid_world_map_name"]
        env_spec = map_parser(f"maps/{map_name}.txt")
        env_params = {
            "randomize_starts": randomize_starts,
            "collision_reward": float(params.get("grid_world_collision_penalty", -10)),
            "agents_bounce": is_true(params.get("grid_world_agents_bounce", False)),
            "terminate_on_collision": is_true(params.get("grid_world_terminate_on_collision", False)),
        }

        # The continuous world ignores grid_world_obs_type, so it is checked
        # before that key is read.
        if map_type == "ContinuousGridWorld":
            return ContinuousGridWorld(*env_spec, **env_params)

        obs_type = params["grid_world_obs_type"]
        if obs_type == "PartialObsDiscrete":
            return FastGridWorldPartialObs(*env_spec, **env_params)
        if obs_type == "FullObsDiscrete":
            return FastGridWorld(*env_spec, **env_params)
        if obs_type == "NearbyObsDiscrete":
            return FastGridWorldNearbyObs(*env_spec, **env_params)
        if obs_type == "2DObs":
            # The observation radius is optional; when absent the
            # environment's own default is used.
            if "grid_world_obs_radius" in params:
                env_params["other_agent_obs_radius"] = int(params["grid_world_obs_radius"])
            return FastGridWorld2DObs(*env_spec, **env_params)
        raise NotImplementedError(f"Unsupported grid_world_obs_type: {obs_type!r}")

    if map_type == "ParticleMomentum":
        world_size = int(params["particle_world_size"])
        collision_reward = float(params["particle_collision_penalty"])
        terminate_on_collision = is_true(params["particle_terminate_on_collision"])
        agents_observe_momentum = is_true(params["particle_agents_observe_momentum"])
        return ParticleMomentum(
            world_size=world_size,
            agents_observe_momentum=agents_observe_momentum,
            randomize_starts=randomize_starts,
            collision_reward=collision_reward,
            terminate_on_collision=terminate_on_collision,
        )

    raise NotImplementedError(f"Unsupported map_type: {map_type!r}")


def _build_learner(params, env, max_total_steps, max_num_episodes):
    """Create one independent learner per agent and wrap them together.

    Epsilon is annealed linearly over ``max_total_steps`` when that limit is
    given, and over ``max_num_episodes`` otherwise.

    :param params: flat mapping of experiment parameters
    :param env: environment providing per-agent observation and action spaces
    :param max_total_steps: training step limit, or ``None``
    :param max_num_episodes: training episode limit, or ``None``
    :return: a ``MultiAgentLearnerWrapper`` around the per-agent learners
    :raises NotImplementedError: for an unknown ``learner_type``
    :raises KeyError: for an unknown ``learner_deep_network_model``
    """
    eps_start = float(params["learner_anneal_eps_start"])
    eps_finish = float(params["learner_anneal_eps_finish"])
    # NOTE(review): if both limits are None the episode scheduler receives
    # None as its horizon -- confirm whether that is supported.
    if max_total_steps is None:
        eps_scheduler = linear_epsilon_anneal_episodes(eps_start, eps_finish, max_num_episodes)
    else:
        eps_scheduler = linear_epsilon_anneal_steps(eps_start, eps_finish, max_total_steps)

    learner_params = {"epsilon_scheduler": eps_scheduler,
                      "evaluation_epsilon": float(params.get("learner_evaluation_epsilon", 0))}

    learner_type = params["learner_type"]
    if learner_type == "Individual_Q":
        learner_class = TabularQLearner
    elif learner_type == "Individual_Deep_Q":
        learner_class = DeepQLearner
        learner_params["model_class"] = {
            "simple_mlp": SimpleMLP,
            "norm_simple_mlp": NormSimpleMLP,
            "simple_cnn": SimpleCNN
        }[params["learner_deep_network_model"]]
        learner_params["make_multidiscrete_one_hot"] = is_true(params.get("learner_transform_one_hot"))
    else:
        raise NotImplementedError(f"Unsupported learner_type: {learner_type!r}")

    agents = [learner_class(obs_space, action_space, **learner_params)
              for obs_space, action_space in zip(env.agent_obs_spaces(), env.agent_actions_spaces())]
    return MultiAgentLearnerWrapper(agents)


def _build_label_extractor(params, env):
    """Return the label extractor matching the environment, or ``None`` if there is none."""
    map_type = params["map_type"]
    if map_type == "GridWorld" and params["grid_world_obs_type"] == "FullObsDiscrete":
        return FullObsGridWorldLabelExtractor2Agents(env)
    if map_type == "ContinuousGridWorld":
        return ContinuousGridWorldLabelExtractor(env)
    if map_type == "ParticleMomentum":
        return ParticleMomentumLabelExtractor2Agents(env)
    return None


def _build_slugs_shield_kwargs(params, prefix, label_extractor):
    """Collect the extra constructor arguments needed by slugs-based shields.

    The parameter keys read are ``<prefix>shield``,
    ``<prefix>shield_specification`` and
    ``<prefix>shield_decentralized_random_agent_order``; ``prefix`` is ``""``
    for the training shield and ``"evaluation_"`` for the evaluation shield.

    :param params: flat mapping of experiment parameters
    :param prefix: key prefix selecting which shield's parameters to read
    :param label_extractor: label extractor for the environment, or ``None``
    :return: dict of extra keyword arguments (empty for non-slugs shields)
    :raises ValueError: if a slugs shield is requested but no label extractor exists
    """
    shield_name = params[f"{prefix}shield"]
    kwargs = {}
    if shield_name not in ("slugs_centralized", "slugs_decentralized"):
        return kwargs

    if label_extractor is None:
        raise ValueError("No label extractor was defined for current map type")
    kwargs["label_extractor"] = label_extractor

    specification = params[f"{prefix}shield_specification"]
    if shield_name == "slugs_centralized":
        kwargs["shield_spec"] = load_centralized_shield(specification)
    else:
        kwargs["shield_spec"] = load_decentralized_shield(specification)
        kwargs["random_agent_order"] = is_true(
            params.get(f"{prefix}shield_decentralized_random_agent_order", True))
    return kwargs


def get_config_from_params(params):
    """Build the training and evaluation configurations for one experiment run.

    Seeds ``random``, ``numpy`` and ``torch`` from ``params["seed"]`` first,
    then constructs, in this order, the environment, the learners, the
    training shield and the evaluation shield. When ``params`` has no
    ``"evaluation_shield"`` entry the training shield instance is reused for
    evaluation.

    :param params: flat mapping of experiment parameters (values may be
        strings); ``seed``, ``run_name``, ``map_type``, ``learner_type``,
        ``shield`` and the epsilon-annealing bounds are always required, the
        remaining keys depend on the selected map, learner and shield types
    :return: ``(config, evaluation_config)`` -- a ``Configuration`` for
        training and a ``TestConfiguration`` for evaluation
    :raises NotImplementedError: for an unsupported map, observation or learner type
    :raises ValueError: if a slugs shield is requested for an environment
        without a label extractor
    :raises KeyError: if a required parameter is missing or names an unknown shield
    """
    seed = int(params["seed"])
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    run_name = params["run_name"]
    randomize_starts = is_true(params.get("randomize_starts", False))
    max_total_steps = params.get("max_total_steps")
    if max_total_steps is not None:
        max_total_steps = int(max_total_steps)
    max_num_episodes = params.get("max_num_episodes")
    if max_num_episodes is not None:
        max_num_episodes = int(max_num_episodes)

    # Construction order is kept fixed so that consumption of the seeded
    # random number generators stays reproducible between runs.
    env = _build_env(params, randomize_starts)
    multi_agent = _build_learner(params, env, max_total_steps, max_num_episodes)
    label_extractor = _build_label_extractor(params, env)

    other_shield_params = _build_slugs_shield_kwargs(params, "", label_extractor)
    punish_unsafe_orig_action = is_true(params.get("punish_unsafe_orig_action", False))
    if punish_unsafe_orig_action:
        # float(None) raises TypeError here when the modifier is not supplied.
        shield = shields[params["shield"]](punish_unsafe_orig_action=True, unsafe_action_punishment=float(
            params.get("punish_unsafe_orig_action_modifier")), env=env, **other_shield_params)
    else:
        shield = shields[params["shield"]](punish_unsafe_orig_action=False, env=env, **other_shield_params)

    if "evaluation_shield" in params:
        other_eval_shield_params = _build_slugs_shield_kwargs(params, "evaluation_", label_extractor)
        # The evaluation shield never punishes the original unsafe action.
        evaluation_shield = shields[params["evaluation_shield"]](punish_unsafe_orig_action=False, env=env,
                                                                 **other_eval_shield_params)
    else:
        evaluation_shield = shield

    config = Configuration(
        shield=shield,
        env=env,
        learner=multi_agent,
        run_name=run_name,
        limits=TrainingLimits(max_episode_len=500, max_total_steps=max_total_steps, max_num_episodes=max_num_episodes),
        num_log_entries=200,
        num_checkpoints=10,
    )

    evaluation_config = TestConfiguration(
        shield=evaluation_shield,
        env=env,
        agent=multi_agent,
        run_name=params.get("evaluation_run_name", run_name),
        limits=TestingLimits(max_episode_len=500, num_episodes=100)
    )

    return config, evaluation_config


def run_with_params(params):
    """Run one experiment: build its configs, then train and evaluate.

    Training resumes from a checkpoint when one is found for the run name.
    Training and evaluation can each be skipped with the ``skip_training``
    and ``skip_evaluation`` parameters.

    :param params: flat mapping of experiment parameters, passed on to
        ``get_config_from_params``
    """
    # TODO Set from environment variable
    thread_count = 4
    torch.set_num_threads(thread_count)
    torch.set_num_interop_threads(thread_count)

    train_config, eval_config = get_config_from_params(params)

    # Restore learner weights before training or evaluating, if a
    # checkpoint for this run exists.
    restored = maybe_load_from_checkpoint(train_config.run_name)
    if restored:
        train_config.learner.load_state_dict(restored.learner_state_dict)

    if not is_true(params.get("skip_training", False)):
        train_loop(train_config, restored)

    if not is_true(params.get("skip_evaluation", False)):
        test_loop(eval_config)


# Script entry point: hand run_with_params to the experiment runner.
# NOTE(review): presumably the runner parses command-line arguments and calls
# run_with_params once per CSV row; the meaning of the empty list argument is
# not visible here -- see experiments.utils.parallel_run.
if __name__ == '__main__':
    parse_args_and_run_parallel_csv_experiment(run_with_params, [])
